#!/usr/bin/env python3
"""
Run evaluations over GMO campaign datasets.

This script processes JSONL files where each line contains a 'prompt'
with in-context examples and a 'response' with the ground-truth tweet.
It processes one --start/--end slice of each dataset per invocation, so
several invocations can run in parallel over different slices.
"""
# --------------------------------------------------------------------------------------
# Imports & Config
# --------------------------------------------------------------------------------------
import argparse
import json
import os
import re
import sys
import random
from typing import Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
# from openai import AzureOpenAI  # Removed – switching to local Qwen model

# --------------------------------------------------------------------------------------
# External model / API imports
# --------------------------------------------------------------------------------------
# Inference uses a local Hugging Face Qwen model (loaded in main()).
# The earlier Azure OpenAI GPT backend and its configuration are no longer needed.
# --------------------------------------------------------------------------------------

# --------------------------------------------------------------------------------------
# Model
# --------------------------------------------------------------------------------------
# Module-level handles shared by main() (which assigns them) and verbalize()
# (which reads them); both stay None until main() has loaded the model.
model, tokenizer = None, None

# --------------------------------------------------------------------------------------
# Single system prompt (no persona)
# --------------------------------------------------------------------------------------

# System message sent with every generation request in verbalize().
# NOTE(review): this prompt asks for FOUR tweets returned as a JSON list, while
# the user message built in run_gmo_evaluation() asks for a single tweet and
# verbalize() caps generation at 300 new tokens — confirm these are meant to agree.
SYSTEM_PROMPT = """You are a master social-media strategist. You will receive an ad input and several example tweets.

Your objectives:
1.  Synthesize the key insights from the examples.
2.  Generate FOUR distinct, high-quality tweets that cater to the ad input. Each tweet should explore a slightly different angle or tone.
3.  Keep each tweet ≤ 280 characters.
4.  Ensure the tweets are professional, engaging, and use hashtags effectively.

Return ONLY a JSON-formatted list of 4 strings, where each string is a generated tweet.
Example format:
[
    "This is the first tweet, leveraging a direct and bold tone. #Innovation",
    "Here's a second option, focusing more on the emotional and community aspect. #Together",
    "A third tweet, using a question to drive engagement. What do you think? #Future",
    "The fourth and final tweet, more professional and corporate in style. #Official"
]

Do not include any other text, analysis, or commentary in your response."""

# --------------------------------------------------------------------------------------
# Helper functions
# --------------------------------------------------------------------------------------

def verbalize(full_prompt: str, *_, **__) -> str:
    """Sample one completion for *full_prompt* from the module-level model.

    The prompt is sent as the user turn (prefixed with ``/no_think``) after
    ``SYSTEM_PROMPT``. Extra positional and keyword arguments are accepted
    and ignored.

    Returns:
        The decoded continuation only (prompt tokens removed), with special
        tokens skipped and surrounding whitespace stripped.
    """
    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": "/no_think" + full_prompt},
    ]

    prompt_ids = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        enable_thinking=False,
    ).to(model.device)
    # Remember how many tokens belong to the prompt so they can be cut off
    # from the generated sequence below.
    prompt_len = len(prompt_ids[0])

    sampling_options = {
        "max_new_tokens": 300,
        "temperature": 0.85,
        "use_cache": True,
        "do_sample": True,
        "min_p": 0.1,
    }
    with torch.no_grad():
        generated = model.generate(input_ids=prompt_ids, **sampling_options)

    completion_ids = generated[0][prompt_len:]
    completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return completion.strip()

# --------------------------------------------------------------------------------------
# CLI
# --------------------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    """Parse the command-line options for a GMO evaluation run.

    Returns:
        The parsed namespace with ``start``, ``end``, ``output_dir``,
        ``gpu_id``, ``dataset_paths`` (required), ``max_examples`` and
        ``tweet_eval``.
    """
    cli = argparse.ArgumentParser(
        description="Run GMO evaluation over a slice of a campaign dataset.",
    )

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--start", {"type": int, "default": 0, "help": "Start index (inclusive) of the slice."}),
        ("--end", {"type": int, "default": None, "help": "End index (inclusive) of the slice."}),
        ("--output_dir", {"type": str, "default": "gmo_results", "help": "Directory to write JSON results."}),
        ("--gpu_id", {"type": int, "default": 0, "help": "GPU ID to use for inference."}),
        ("--dataset_paths", {"type": str, "required": True, "help": "Comma-separated list of *.jsonl datasets to evaluate."}),
        ("--max_examples", {"type": int, "default": None, "help": "(Optional) truncate dataset to this many examples – useful for quick smoke tests."}),
        # --tweet_eval only exists for compatibility with the run.sh script; it
        # can be dropped once the runner stops passing it.
        ("--tweet_eval", {"action": "store_true", "help": "Legacy flag to trigger this evaluation path."}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()

# --------------------------------------------------------------------------------------
# Main evaluation logic
# --------------------------------------------------------------------------------------

def main() -> None:
    """Parse CLI arguments, load the Qwen model once, and run the evaluation.

    The tokenizer and model are stored in the module-level ``tokenizer`` and
    ``model`` globals so that ``verbalize`` can use them.

    NOTE(review): ``--gpu_id`` is parsed but not consulted in this function;
    device placement is left to ``device_map="auto"``.
    """
    global model, tokenizer
    # Imported here so the file's top-level import block stays untouched.
    from transformers import BitsAndBytesConfig

    args = parse_args()

    # -----------------------------------------------------------------------------
    # Load Model once at the start
    # -----------------------------------------------------------------------------
    # Load local Qwen model (4-bit for memory efficiency). The quantisation
    # settings go through `quantization_config`: passing `load_in_4bit=True`
    # straight to `from_pretrained` is deprecated.
    model_name = "Qwen/Qwen3-32B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )

    # The script now only performs one task, so we can call it directly.
    run_gmo_evaluation(args)


def _load_jsonl_records(path: str, max_examples=None):
    """Read JSON objects from the JSONL file *path*.

    Reading stops once the line index reaches *max_examples* (when that is
    truthy). Blank lines are ignored; lines that are not valid JSON, or whose
    JSON value is not an object, are skipped and counted.

    Returns:
        A ``(records, skipped)`` pair: the list of parsed dicts and the number
        of lines that were skipped as malformed.
    """
    records: List[Dict] = []
    skipped = 0
    with open(path, "r", encoding="utf-8") as f_in:
        for line_idx, line in enumerate(f_in):
            if max_examples and line_idx >= max_examples:
                break
            if not line.strip():
                continue
            try:
                parsed = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            if isinstance(parsed, dict):
                records.append(parsed)
            else:
                # Later code calls .get() on every record, so non-objects
                # cannot be processed.
                skipped += 1
    return records, skipped


def _format_few_shot_example(rec: Dict) -> str:
    """Render one record as an ``Input: ... / Tweet: ...`` few-shot example."""
    response = rec.get("response")
    tweet = "" if response is None else str(response)
    return f"Input: {rec.get('prompt', '')}\nTweet: {tweet.replace('<hyperlink>', '').strip()}"


def _build_few_shot_context(example_lines: List[str]) -> str:
    """Join rendered examples under an instruction header ("" if there are none)."""
    count = len(example_lines)
    if count == 0:
        return ""
    noun = "example" if count == 1 else "examples"
    header = (
        f"You are given {count} {noun}. Study them carefully and learn their format. "
        "Then generate a tweet that follows this pattern.\n\n"
    )
    return header + "\n\n".join(example_lines)


def _write_results(out_path: str, results: List[Dict]) -> None:
    """Write *results* as JSON to *out_path* via a temp file and ``os.replace``."""
    tmp_path = f"{out_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f_out:
        json.dump(results, f_out, indent=2)
    # Swap the finished file in so a crash mid-write cannot leave a truncated
    # results file behind.
    os.replace(tmp_path, out_path)


def run_gmo_evaluation(args):
    """
    End-to-end evaluation on GMO datasets containing {"prompt":..., "response":...} per line.

    For every path in the comma-separated ``args.dataset_paths`` this loads the
    records, applies the inclusive ``args.start``/``args.end`` slice, builds a
    few-shot prompt for each record, generates text with ``verbalize`` and writes
    the results to
    ``<args.output_dir>/gmo_results_<dataset file name>_<start>_<end>.json``.
    The results file is rewritten after every record so that partial progress
    survives an interruption.

    Up to 5 few-shot examples are drawn from the slice with a fixed seed (42).
    A record is never shown as an example in its own prompt, so its ground-truth
    response is not leaked into the context used to generate for it.

    Args:
        args: Parsed CLI namespace. Reads ``dataset_paths``, ``output_dir`` and
            ``max_examples``; ``start`` and ``end`` are optional attributes.

    Raises:
        SystemExit: If ``args.dataset_paths`` contains no usable path.
    """
    # Resolve dataset paths
    dset_paths = [p.strip() for p in (args.dataset_paths or "").split(",") if p.strip()]
    if not dset_paths:
        # This path should not be taken if run via run.sh
        print("[ERROR] --dataset_paths is a required argument.", file=sys.stderr)
        sys.exit(1)

    overall_out_dir = args.output_dir
    os.makedirs(overall_out_dir, exist_ok=True)

    for dpath in dset_paths:
        dataset_name = os.path.basename(dpath)
        print(f"\n[INFO] Processing dataset: {dataset_name}")

        records, skipped = _load_jsonl_records(dpath, args.max_examples)
        if skipped:
            print(f"[WARNING] Skipped {skipped} malformed line(s) in {dataset_name}")

        # --- Apply slicing if --start/--end are provided ---
        start_arg = getattr(args, "start", None)
        end_arg = getattr(args, "end", None)
        last_idx = len(records) - 1
        slice_start = max(0, start_arg) if start_arg is not None else 0
        slice_end = min(end_arg, last_idx) if end_arg is not None else last_idx
        if slice_start > 0 or slice_end < last_idx:
            records = records[slice_start : slice_end + 1]
            print(f"[INFO] Processing slice {slice_start}-{slice_end} (n={len(records)}) of {dataset_name}")
        else:
            print(f"[INFO] Processing full dataset {dataset_name} (n={len(records)})")

        slice_suffix = f"_{slice_start}_{slice_end}"
        # Use a more specific output filename to avoid clashes
        out_path = os.path.join(overall_out_dir, f"gmo_results_{dataset_name}{slice_suffix}.json")

        if not records:
            # Nothing to generate; still leave an (empty) results file behind.
            print(f"[WARNING] No records to process for {dataset_name}; writing empty results to {out_path}")
            _write_results(out_path, [])
            continue

        # ------------------------------------------------------------------
        # Build few-shot examples (randomly choose up to 5 from current slice)
        # ------------------------------------------------------------------
        # A private RNG keeps the draw reproducible without reseeding the
        # global `random` state. Sampling indices rather than records lets the
        # loop below tell when a record is one of its own examples.
        rng = random.Random(42)
        sampled_indices = rng.sample(range(len(records)), min(5, len(records)))
        example_lines = {i: _format_few_shot_example(records[i]) for i in sampled_indices}
        shared_context = _build_few_shot_context(list(example_lines.values()))

        all_results = []

        for idx, rec in enumerate(tqdm(records, desc=dataset_name)):
            # The 'prompt' field from the dataset contains the media description and context.
            ad_prompt = rec.get("prompt", "")
            gt_resp = rec.get("response", None)  # ground-truth tweet (if provided)

            if idx in example_lines:
                # Drop this record from its own context to avoid leaking its answer.
                context = _build_few_shot_context(
                    [text for i, text in example_lines.items() if i != idx]
                )
            else:
                context = shared_context

            query = f"Now you are given a new tweet to create.\nInput: {ad_prompt}\nTweet:"
            combined_prompt = f"{context}\n\n{query}" if context else query

            # Print the full prompt for transparency/debugging
            print("\n[MODEL PROMPT]\n" + combined_prompt + "\n")

            resp_text = verbalize(combined_prompt, model, tokenizer, args)

            # Store generated tweet for later evaluation (e.g., ROUGE / BLEU)
            all_results.append({
                "prompt": ad_prompt,
                "ground_truth": gt_resp,
                "generated": resp_text,
            })

            # Incremental save after every example to avoid data loss. This is
            # best-effort: a failed save is reported but does not stop the run.
            try:
                _write_results(out_path, all_results)
            except (OSError, TypeError, ValueError) as save_err:
                print(f"[WARNING] Incremental save failed: {save_err}")

        # — Final save per-dataset results (errors here are not suppressed)
        _write_results(out_path, all_results)

        print(f"[INFO] Completed processing {dataset_name} slice {slice_start}-{slice_end}. Results saved to {out_path}")

    print("\n[INFO] GMO ad evaluation complete.")

# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()


            

            

        
        
        

        
